import argparse
import math
import os
import sys

import numpy as np
import torch
from datasets import load_dataset
from seqeval.metrics import f1_score, precision_score, recall_score
from torch.optim import AdamW, Adam, SGD
from transformers import BertTokenizerFast, BertForTokenClassification, Trainer, TrainingArguments, SchedulerType
from transformers import get_linear_schedule_with_warmup
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from optimizers.lamb import create_lamb_optimizer
from optimizers.ALTO import create_ALTO_optimizer
from adabelief_pytorch import AdaBelief
# Command-line configuration for the fine-tuning run.
parser = argparse.ArgumentParser()
parser.add_argument('--epochs', default=3, type=int,
                    help='number of training epochs')
parser.add_argument('-b', '--batch-size', default=1024, type=int,
                    help='global batch size; split evenly across the visible GPUs')
parser.add_argument('--lr', default=5e-5, type=float,
                    help='peak learning rate reached after the linear warmup')
parser.add_argument('--optimizer', default='adam', choices=['sgd', 'adam', 'adamW', 'lamb', 'ALTO', 'adaBelief'],
                    help='optimizer used for training')
# NOTE(review): args.beta is parsed but not read anywhere in this script; the
# optimizer constructors below use hard-coded betas. Wire it in or remove it.
parser.add_argument('--beta', default=0.999, type=float,
                    help='beta value (not currently passed to any optimizer)')
parser.add_argument('--workers', default=0, type=int,
                    help='total dataloader worker processes; split across the visible GPUs')
args = parser.parse_args()

# Tag names indexed by label id: label_list[i] is the tag string for id i.
# compute_metrics uses it to turn predicted and gold ids back into tag strings,
# so its ORDER must match the ids the dataset loader assigns to "ner_tags".
# B-X marks the first token of an X entity. Under IOB1 tagging it appears only
# when an X entity directly follows another X entity; under IOB2 it starts every
# X entity. Which scheme conll2003.py emits is not visible from this file.
# NOTE(review): the Hugging Face `conll2003` loading script declares its
# ClassLabel names as O, B-PER, I-PER, B-ORG, I-ORG, B-LOC, I-LOC, B-MISC,
# I-MISC. If conll2003.py is that script, the entity types below are misnamed
# (aggregate F1/precision/recall are unaffected because B-/I- positions still
# line up, but any per-type reporting would be wrong). Confirm against
# dataset["train"].features["ner_tags"].feature.names.
label_list = [
    "O",       # Outside of a named entity
    "B-MISC",  # Beginning of a miscellaneous entity
    "I-MISC",  # Inside a miscellaneous entity
    "B-PER",   # Beginning of a person's name
    "I-PER",   # Inside a person's name
    "B-ORG",   # Beginning of an organization
    "I-ORG",   # Inside an organization
    "B-LOC",   # Beginning of a location
    "I-LOC"    # Inside a location
]

def compute_metrics(p):
    """Compute entity-level F1, precision and recall with seqeval.

    Args:
        p: Evaluation output from the Trainer. ``p.predictions`` holds
            per-token scores whose argmax over axis 2 is the predicted label
            id; ``p.label_ids`` holds the gold label ids, where -100 marks
            positions without a gold tag (see tokenize_and_align_labels).

    Returns:
        A dict with keys ``"f1"``, ``"precision"`` and ``"recall"``, each the
        seqeval score over the positions that carry a gold tag.
    """
    ignored = -100
    predicted_ids = np.argmax(p.predictions, axis=2)
    gold_ids = p.label_ids

    # Map ids back to tag strings, dropping ignored positions. Predictions are
    # filtered by the *gold* mask so both sequences keep the same length.
    gold_tags = [
        [label_list[gold] for gold in gold_row if gold != ignored]
        for gold_row in gold_ids
    ]
    predicted_tags = [
        [label_list[guess] for guess, gold in zip(pred_row, gold_row) if gold != ignored]
        for pred_row, gold_row in zip(predicted_ids, gold_ids)
    ]

    scorers = (("f1", f1_score), ("precision", precision_score), ("recall", recall_score))
    return {name: scorer(gold_tags, predicted_tags) for name, scorer in scorers}

# Use the GPU when one is visible; the model is moved there explicitly below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")

# NOTE(review): "conll2003.py" is a local dataset loading script; recent
# `datasets` releases may refuse to run loading scripts — confirm the pinned version.
dataset = load_dataset("conll2003.py")

# NOTE(review): placeholder path — point this at a local bert-base-cased checkpoint.
model_dir = "....../bert-base-cased"

tokenizer = BertTokenizerFast.from_pretrained(model_dir)
# Size the classification head from label_list so the number of output
# classes cannot drift from the tag names used in compute_metrics.
model = BertForTokenClassification.from_pretrained(model_dir, num_labels=len(label_list)).to(device)

def tokenize_and_align_labels(examples):
    """Tokenize pre-split words and align word-level NER tags to sub-tokens.

    Args:
        examples: A batch (as passed by ``Dataset.map(batched=True)``) with a
            ``"tokens"`` column of word lists and a ``"ner_tags"`` column of
            per-word label ids.

    Returns:
        The tokenizer's encoding (truncated/padded to 128 tokens) with an
        added ``"labels"`` entry. For each sequence, the first sub-token of
        every word gets that word's tag; continuation sub-tokens and
        special/padding positions get -100, which is filtered out again in
        compute_metrics.
    """
    encoded = tokenizer(
        examples["tokens"],
        truncation=True,
        padding="max_length",
        max_length=128,
        is_split_into_words=True,
    )

    aligned_batch = []
    for row, word_tags in enumerate(examples["ner_tags"]):
        row_word_ids = encoded.word_ids(batch_index=row)
        # Pair each sub-token's word index with the one before it: a token
        # starts a new word only when its index differs from its predecessor.
        predecessors = [None] + row_word_ids[:-1]
        aligned_batch.append([
            -100 if current is None or current == prior else word_tags[current]
            for current, prior in zip(row_word_ids, predecessors)
        ])

    encoded["labels"] = aligned_batch
    return encoded

# Tokenize every split of the dataset, attaching sub-token-aligned labels.
tokenized_datasets = dataset.map(function=tokenize_and_align_labels, batched=True)

# Define the training arguments.
# torch.cuda.device_count() is 0 on a CPU-only machine; clamp to 1 so the
# per-device splits below never divide by zero.
n_gpus = max(torch.cuda.device_count(), 1)
training_args = TrainingArguments(
    output_dir="./results",
    # --batch-size is a global value split across the visible GPUs, never
    # going below one example per device.
    per_device_train_batch_size=max(args.batch_size // n_gpus, 1),
    per_device_eval_batch_size=max(args.batch_size // n_gpus, 1),
    num_train_epochs=args.epochs,
    seed=42,
    # Reloading the best checkpoint requires matching eval/save strategies,
    # which both run once per epoch here.
    load_best_model_at_end=True,
    # NOTE(review): newer transformers releases rename this to `eval_strategy`.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    # Native AMP needs CUDA; fall back to fp32 when no GPU is available.
    fp16=torch.cuda.is_available(),
    logging_dir='./logs',
    logging_steps=10,
    # Not used at runtime: a scheduler is passed explicitly via Trainer(optimizers=...).
    lr_scheduler_type=SchedulerType.LINEAR,
    # --workers is likewise a total split across the visible GPUs.
    dataloader_num_workers=args.workers // n_gpus
)

learning_rate = args.lr

# Build the optimizer named on the command line. argparse `choices` already
# restricts args.optimizer, so the final `else` is only a safeguard.
if args.optimizer == 'sgd':
    optimizer = SGD(model.parameters(), lr=learning_rate)
elif args.optimizer == 'adam':
    optimizer = Adam(model.parameters(), lr=learning_rate)
elif args.optimizer == 'adamW':
    optimizer = AdamW(model.parameters(), lr=learning_rate)
elif args.optimizer == 'adaBelief':
    optimizer = AdaBelief(model.parameters(), lr=learning_rate, betas=(0.9, 0.999))
elif args.optimizer == 'ALTO':
    optimizer = create_ALTO_optimizer(model, lr=learning_rate, betas=(0.99, 0.9, 0.99), weight_decay=1e-4)
elif args.optimizer == 'lamb':
    optimizer = create_lamb_optimizer(model, lr=learning_rate, weight_decay=1e-4)
else:
    raise ValueError('Unknown optimizer: {}'.format(args.optimizer))

# The schedule must span the optimizer steps the Trainer actually takes.
# One batch consumes `train_batch_size` examples per process (per-device batch
# times visible GPUs under DataParallel) on each of `world_size` processes.
# The final partial batch still counts (dataloader_drop_last is left at its
# default), hence the ceiling. Dividing by the per-device batch size instead
# would overcount steps on multi-GPU runs, stretching the warmup and leaving
# the learning rate far from zero when training ends.
examples_per_batch = training_args.train_batch_size * training_args.world_size
batches_per_epoch = math.ceil(len(tokenized_datasets["train"]) / examples_per_batch)
steps_per_epoch = max(batches_per_epoch // training_args.gradient_accumulation_steps, 1)
num_training_steps = math.ceil(steps_per_epoch * training_args.num_train_epochs)
# Linear warmup over the first 10% of steps, then linear decay; the scheduler
# expects integer step counts.
num_warmup_steps = int(num_training_steps * 0.1)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps
)

# Wire model, data, metrics and the custom optimizer/scheduler pair into the
# Trainer; passing `optimizers` makes it use these instead of building its own.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    optimizers=(optimizer, lr_scheduler),
    compute_metrics=compute_metrics,
)

print("Start training...")
trainer.train()

# Final score on the held-out test split. The "test" prefix keeps these
# metrics (test_f1, test_loss, ...) separate from the per-epoch validation
# metrics, which are logged under the default "eval" prefix.
evaluation_results = trainer.evaluate(eval_dataset=tokenized_datasets["test"], metric_key_prefix="test")
print("Evaluation results:", evaluation_results)
